{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SparseSinkhorn Solver: Barycenters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from lib.header_notebook import *\n",
    "import Solvers.Sinkhorn as Sinkhorn\n",
    "import Solvers.Sinkhorn.Barycenter as Barycenter\n",
    "import lib.header_params_Sinkhorn\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# verbosity: disable iteration output, may become really slow in notebooks\n",
    "paramsVerbose={\n",
    "        \"solve_overview\":True,\\\n",
    "        \"solve_update\":True,\\\n",
    "        \"solve_kernel\":True,\\\n",
    "        \"solve_iterate\":False\\\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "params=lib.header_params_Sinkhorn.getParamsDefaultBarycenter()\n",
    "paramsListCommandLine,paramsListCFGFile=lib.header_params_Sinkhorn.getParamListsBarycenter()\n",
    "\n",
    "params[\"setup_tag\"]=\"cfg/Sinkhorn/Barycenter/OT_simple_1-2-1\"\n",
    "#params[\"setup_tag\"]=\"cfg/Sinkhorn/Barycenter/WF_three_1-2-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params[\"setup_cfgfile\"]=params[\"setup_tag\"]+\".txt\"\n",
    "\n",
    "# load parameters from config file\n",
    "params.update(ScriptTools.readParameters(params[\"setup_cfgfile\"],paramsListCFGFile))\n",
    "\n",
    "# interpreting some parameters\n",
    "\n",
    "# totalMass regulates, whether marginals should be normalized or not.\n",
    "if params[\"setup_totalMass\"]<0:\n",
    "    params[\"setup_totalMass\"]=None\n",
    "# finest level for multi-scale algorithm\n",
    "params[\"hierarchy_lBottom\"]=params[\"hierarchy_depth\"]+1\n",
    "\n",
    "\n",
    "print(\"final parameter settings\")\n",
    "for k in sorted(params.keys()):\n",
    "    print(\"\\t\",k,params[k])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def loadProblem(filename):\n",
    "    img=sciio.loadmat(filename)[\"a\"]\n",
    "    return img\n",
    "\n",
    "def setupDensity(img,constOffset=0,posScale=1.,keepZero=False,totalMass=None):\n",
    "    (mu,pos)=OTTools.processDensity_Grid(img,totalMass=totalMass,constOffset=constOffset,keepZero=keepZero)\n",
    "    pos=pos/posScale\n",
    "    return (mu,pos,img.shape)\n",
    "\n",
    "problemData=[setupDensity(loadProblem(filename),\\\n",
    "        posScale=params[\"setup_posScale\"],constOffset=params[\"setup_constOffset\"],\\\n",
    "        keepZero=False,totalMass=params[\"setup_totalMass\"]\\\n",
    "        )\\\n",
    "        for filename in params[\"setup_fileList\"]]\n",
    "\n",
    "problemDataCenter=setupDensity(np.full(params[\"setup_centerRes\"],1.,dtype=np.double),posScale=params[\"setup_posScale\"])\n",
    "nProblems=len(problemData)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# visualize marginals\n",
    "fig=plt.figure(figsize=(4*nProblems,4))\n",
    "for i in range(nProblems):\n",
    "    img=OTTools.ProjectInterpolation2D(problemData[i][1],problemData[i][0],problemData[i][2][0],problemData[i][2][1])\n",
    "    img=img.toarray()\n",
    "    fig.add_subplot(1,nProblems,i+1)\n",
    "    plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# set up eps-scaling\n",
    "\n",
    "# geometric scaling from eps_start to eps_target in eps_steps+1 steps\n",
    "params.update(\n",
    "        Sinkhorn.Aux.SetupEpsScaling_Geometric(params[\"eps_target\"],params[\"eps_start\"],params[\"eps_steps\"],\\\n",
    "        verbose=True))\n",
    "\n",
    "# determine finest epsilon for each hierarchy level.\n",
    "# on coarsest level it is given by params[\"eps_boxScale\"]**params[\"eps_boxScale_power\"]\n",
    "# with each level, the finest scale params[\"eps_boxScale\"] is effectively divided by 2\n",
    "params[\"eps_scales\"]=[(params[\"eps_boxScale\"]/(2**n))**params[\"eps_boxScale_power\"]\\\n",
    "        for n in range(params[\"hierarchy_depth\"]+1)]+[0]\n",
    "print(\"eps_scales:\\t\",params[\"eps_scales\"])\n",
    "\n",
    "# divide eps_list into eps_lists, one for each hierarchy scale, divisions determined by eps_scales.\n",
    "params.update(Sinkhorn.Aux.SetupEpsScaling_Scales(params[\"eps_list\"],params[\"eps_scales\"],\\\n",
    "        levelTop=params[\"hierarchy_lTop\"], nOverlap=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# construct hierarchical representation\n",
    "partitionChildMode=HierarchicalPartition.THPMode_Tree\n",
    "\n",
    "# constructing basic partitions\n",
    "partitionList=[]\n",
    "for i in range(nProblems):\n",
    "    partition=HierarchicalPartition.GetPartition(problemData[i][1],params[\"hierarchy_depth\"],partitionChildMode,\\\n",
    "            box=None, signal_pos=True, signal_radii=True,clib=SolverCFC, export=False, verbose=False,\\\n",
    "            finestDimsWarning=False)\n",
    "    partitionList.append(partition)\n",
    "\n",
    "partitionCenter=HierarchicalPartition.GetPartition(problemDataCenter[1],params[\"hierarchy_depth\"],partitionChildMode,\\\n",
    "            box=None, signal_pos=True, signal_radii=True,clib=SolverCFC, export=False, verbose=False,\\\n",
    "            finestDimsWarning=False)\n",
    "\n",
    "# exporting partitions\n",
    "pointerListPartition=np.zeros((nProblems),dtype=np.int64)\n",
    "for i in range(nProblems):\n",
    "    pointerListPartition[i]=SolverCFC.Export(partitionList[i])\n",
    "\n",
    "pointerPartitionCenter=SolverCFC.Export(partitionCenter)\n",
    "\n",
    "muHList=[SolverCFC.GetSignalMass(pointer,partition,aprob[0])\n",
    "        for pointer,partition,aprob in zip(pointerListPartition,partitionList,problemData)]\n",
    "\n",
    "muHCenterList=SolverCFC.GetSignalMass(pointerPartitionCenter,partitionCenter,problemDataCenter[0])\n",
    "\n",
    "# pointer lists\n",
    "pointerPosList=[HierarchicalPartition.getSignalPointer(partition,\"pos\") for partition in partitionList]\n",
    "pointerRadiiList=[HierarchicalPartition.getSignalPointer(partition,\"radii\",lBottom=partition.nlayers-2)\n",
    "        for partition in partitionList]\n",
    "\n",
    "pointerPosCenter=HierarchicalPartition.getSignalPointer(partitionCenter,\"pos\")\n",
    "pointerRadiiCenter=HierarchicalPartition.getSignalPointer(partitionCenter,\"radii\",lBottom=partition.nlayers-2)\n",
    "\n",
    "pointerListPos=np.array([pointerPos.ctypes.data for pointerPos in pointerPosList],dtype=np.int64)\n",
    "pointerListRadii=np.array([pointerRadii.ctypes.data for pointerRadii in pointerRadiiList],dtype=np.int64)\n",
    "\n",
    "\n",
    "# pairwise pointer lists\n",
    "pointerListPosPairs=[np.array([pointerListPos[i],pointerPosCenter.ctypes.data],dtype=np.int64)\n",
    "        for i in range(nProblems)]\n",
    "pointerListRadiiPairs=[np.array([pointerListRadii[i],pointerRadiiCenter.ctypes.data],dtype=np.int64)\n",
    "        for i in range(nProblems)]\n",
    "pointerListPairsPartition=[np.array([pointerListPartition[i],pointerPartitionCenter],dtype=np.int64)\n",
    "        for i in range(nProblems)]\n",
    "\n",
    "# print a few stats on the created problem\n",
    "for i,partition in enumerate(partitionList):\n",
    "    print(\"cells in partition {:d}: \".format(i),partition.cardLayers)\n",
    "print(\"cells in central partition: \",partitionCenter.cardLayers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solver Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if params[\"model_transportModel\"]==\"ot\":\n",
    "    import Solvers.Sinkhorn.Models.BarycenterOT as ModelBarycenterOT\n",
    "\n",
    "    def get_method_CostFunctionProviderPair(index):\n",
    "        return lambda level, pointerAlpha, alphaFinest=None :\\\n",
    "                Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclidean(pointerListPosPairs[index],\\\n",
    "                        partitionList[0].ndim,level,pointerListRadiiPairs[index],pointerAlpha,alphaFinest\\\n",
    "                        )\n",
    "\n",
    "    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\\\n",
    "            eps, nInnerIterations: \\\n",
    "                    ModelBarycenterOT.Iterate(kernel,alphaList,scalingList,muList,eps,nInnerIterations,\\\n",
    "                            params[\"setup_weightList\"],zeroHandling=True,setZeroInf=True)\n",
    "\n",
    "    method_iterate_error = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps:\\\n",
    "            ModelBarycenterOT.ErrorMarginLInf(kernel,scalingList,muList,zeroHandling=True)\n",
    "\n",
    "elif params[\"model_transportModel\"]==\"wf\":\n",
    "    import Solvers.Sinkhorn.Models.BarycenterWF as ModelBarycenterWF\n",
    "\n",
    "    def get_method_CostFunctionProviderPair(index):\n",
    "        return lambda level, pointerAlpha, alphaFinest=None :\\\n",
    "                Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclideanWF(pointerListPosPairs[index],\\\n",
    "                        partitionList[0].ndim,level,pointerListRadiiPairs[index],pointerAlpha,alphaFinest,\\\n",
    "                        FR_kappa=params[\"model_FR_kappa\"])\n",
    "\n",
    "    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\\\n",
    "            eps, nInnerIterations: \\\n",
    "            ModelBarycenterWF.Iterate(kernel,alphaList,scalingList,muList,eps,nInnerIterations,\\\n",
    "                    params[\"setup_weightList\"], params[\"model_FR_kappa\"],zeroHandling=True)\n",
    "\n",
    "    method_iterate_error = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps:\\\n",
    "            ModelBarycenterWF.ScorePDGap(kernel,alphaList,scalingList,muList,eps,\\\n",
    "                    params[\"setup_weightList\"],params[\"model_FR_kappa\"])\n",
    "\n",
    "elif params[\"model_transportModel\"]==\"ghk\":\n",
    "    import Solvers.Sinkhorn.Models.BarycenterWF as ModelBarycenterWF\n",
    "\n",
    "    def get_method_CostFunctionProviderPair(index):\n",
    "        return lambda level, pointerAlpha, alphaFinest=None :\\\n",
    "                Sinkhorn.CInterface.Setup_CostFunctionProvider_SquaredEuclidean(pointerListPosPairs[index],\\\n",
    "                        partitionList[0].ndim,level,pointerListRadiiPairs[index],pointerAlpha,alphaFinest\\\n",
    "                        )\n",
    "\n",
    "    method_iterate_iterate = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu,\\\n",
    "            eps, nInnerIterations: \\\n",
    "            ModelBarycenterWF.Iterate(kernel,alphaList,scalingList,muList,eps,nInnerIterations,\\\n",
    "                    params[\"setup_weightList\"], params[\"model_FR_kappa\"],zeroHandling=True)\n",
    "\n",
    "    method_iterate_error = lambda kernel, alphaList, scalingList, muList, pointerListScaling, pointerListMu, eps:\\\n",
    "            ModelBarycenterWF.ScorePDGap(kernel,alphaList,scalingList,muList,eps,\\\n",
    "                    params[\"setup_weightList\"],params[\"model_FR_kappa\"])\n",
    "\n",
    "\n",
    "\n",
    "else:\n",
    "    raise ValueError(\"model_transportModel not recognized: \"+params[\"model_transportModel\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "get_method_getKernel = lambda level, muList, muCenter:\\\n",
    "        lambda kernel, alpha, eps:\\\n",
    "                Barycenter.GetKernel_SparseCSR(partitionList,pointerListPartition,\\\n",
    "                        partitionCenter,pointerPartitionCenter,\\\n",
    "                        get_method_CostFunctionProviderPair,\\\n",
    "                        level, alpha, eps, muList, muCenter,\\\n",
    "                        kThresh=params[\"sparsity_kThresh\"],\\\n",
    "                        verbose=paramsVerbose[\"solve_kernel\"])\n",
    "\n",
    "\n",
    "method_deleteKernel = lambda kernel : None\n",
    "\n",
    "\n",
    "method_refineKernel = lambda level, kernel, alphaList, muList, muCenter, eps:\\\n",
    "        Barycenter.RefineKernel_CSR(partitionList, pointerListPartition, partitionCenter, pointerPartitionCenter,\\\n",
    "                get_method_CostFunctionProviderPair,\\\n",
    "                level, kernel, alphaList, eps, muList, muCenter,\\\n",
    "                verbose=paramsVerbose[\"solve_kernel\"])\n",
    "\n",
    "method_getKernelVariablesCount=Barycenter.GetKernelVariablesCount_CSR\n",
    "\n",
    "method_absorbScaling = lambda alphaList,scalingList,eps:\\\n",
    "                Sinkhorn.Method_AbsorbScalings(alphaList,scalingList,eps,\\\n",
    "                        residualScaling=None,minAlpha=[-1E5 for i in range(2*nProblems)])\n",
    "\n",
    "\n",
    "get_method_update = lambda epsList, method_getKernel, method_deleteKernel, method_absorbScaling:\\\n",
    "        lambda status, data:\\\n",
    "                Sinkhorn.Update(status,data,epsList,\\\n",
    "                        method_getKernel, method_deleteKernel, method_absorbScaling,\\\n",
    "                        verbose=paramsVerbose[\"solve_update\"],absorbFinalIteration=True,maxRepeats=20)\n",
    "\n",
    "\n",
    "method_iterate = lambda status,data : Sinkhorn.Method_IterateToPrecision(status,data,\\\n",
    "                method_iterate=method_iterate_iterate,\\\n",
    "                method_error=method_iterate_error,\\\n",
    "                maxError=params[\"sinkhorn_error\"],\\\n",
    "                nInnerIterations=params[\"sinkhorn_nInner\"],maxOuterIterations=params[\"sinkhorn_maxOuter\"],\\\n",
    "                scalingBound=params[\"adaption_scalingBound\"],scalingLowerBound=params[\"adaption_scalingLowerBound\"],\\\n",
    "                verbose=paramsVerbose[\"solve_iterate\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "result=Barycenter.MultiscaleSolver(partitionList,pointerListPartition,\\\n",
    "        partitionCenter, pointerPartitionCenter,\\\n",
    "        muHList,muHCenterList,\\\n",
    "        params[\"eps_lists\"],\\\n",
    "        get_method_getKernel,method_deleteKernel,method_absorbScaling,\\\n",
    "        method_iterate,get_method_update,method_refineKernel,\\\n",
    "        levelTop=params[\"hierarchy_lTop\"],levelBottom=params[\"hierarchy_lBottom\"],\\\n",
    "        verbose=paramsVerbose[\"solve_overview\"],\\\n",
    "        collectReports=True,method_getKernelVariablesCount=method_getKernelVariablesCount\\\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data=result[\"data\"]\n",
    "status=result[\"status\"]\n",
    "setup=result[\"setup\"]\n",
    "setupAux=result[\"setupAux\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data[\"kernel\"]=setupAux[\"method_getKernel\"](data[\"kernel\"],data[\"alpha\"],data[\"eps\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# clean up hieararchical partitions\n",
    "for pointer in pointerListPartition:\n",
    "    SolverCFC.Close(pointer)\n",
    "    \n",
    "SolverCFC.Close(pointerPartitionCenter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post-Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# extract all marginals\n",
    "# two lists. first list: marginals for the reference measures, second list: marginals for the common central barycenter\n",
    "margs=Barycenter.GetMarginals(data[\"kernel\"],data[\"scaling\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# check marginal error / primal-dual gap\n",
    "if params[\"model_transportModel\"]==\"ot\":\n",
    "    \n",
    "    for f,v in zip([\"error\"],[\\\n",
    "            method_iterate_error(data[\"kernel\"],data[\"alpha\"],data[\"scaling\"],data[\"mu\"],None,None,data[\"eps\"])\\\n",
    "           ]):\n",
    "           print(f,\" : \",v)\n",
    "\n",
    "elif params[\"model_transportModel\"] in [\"wf\",\"ghk\"]:\n",
    "\n",
    "    for f,v in zip([\"PD gap\",\"primal\",\"dual\"],[\\\n",
    "            method_iterate_error(data[\"kernel\"],data[\"alpha\"],data[\"scaling\"],data[\"mu\"],None,None,data[\"eps\"]),\\\n",
    "            ModelBarycenterWF.ScorePrimal(data[\"kernel\"],data[\"alpha\"],data[\"scaling\"],data[\"mu\"],data[\"eps\"],\\\n",
    "                    params[\"setup_weightList\"],params[\"model_FR_kappa\"]),\\\n",
    "            ModelBarycenterWF.ScoreDual(data[\"kernel\"],data[\"alpha\"],data[\"scaling\"],data[\"mu\"],data[\"eps\"],\\\n",
    "                    params[\"setup_weightList\"],params[\"model_FR_kappa\"])\\\n",
    "           ]):\n",
    "           print(f,\" : \",v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# compute barycenter from central marginals (weighted average)\n",
    "img=np.zeros(problemDataCenter[2],dtype=np.double)\n",
    "for i in range(nProblems):\n",
    "    img[...]+=params[\"setup_weightList\"][i]*margs[1][i].reshape(problemDataCenter[2])\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
